{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The Odds Are Odd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import math\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "from cleverhans.dataset import CIFAR10\n",
    "from cleverhans.evaluation import batch_eval\n",
    "from cleverhans.model_zoo.madry_lab_challenges.cifar10_model import make_wresnet as ResNet\n",
    "from cleverhans.utils_tf import initialize_uninitialized_global_variables\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"defense/\")\n",
    "\n",
    "from utils import do_eval, init_defense"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 200\n",
    "test_size = 1000\n",
    "\n",
    "data = CIFAR10()\n",
    "x_test, y_test = data.get_set('test')\n",
    "\n",
    "# Keep only the first `test_size` examples, then rescale the images by 255.\n",
    "# Slicing first and multiplying out-of-place (instead of `x_test *= 255`)\n",
    "# leaves the array returned by `data.get_set` unmodified and avoids scaling\n",
    "# examples that are never used.\n",
    "x_test = x_test[:test_size] * 255\n",
    "y_test = y_test[:test_size]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## load the base model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sess = tf.Session()\n",
    "\n",
    "# Input geometry and number of classes are read off the loaded test arrays.\n",
    "img_rows, img_cols, nchannels = x_test.shape[1:4]\n",
    "nb_classes = y_test.shape[1]\n",
    "\n",
    "# Define input TF placeholder\n",
    "x = tf.placeholder(tf.float32, shape=(None, img_rows, img_cols, nchannels))\n",
    "y = tf.placeholder(tf.float32, shape=(None, nb_classes))\n",
    "\n",
    "model = ResNet(scope='ResNet')\n",
    "preds = model.get_logits(x)\n",
    "\n",
    "model_name = 'naturally_trained'\n",
    "ckpt_dir = 'models/'\n",
    "model_dir = os.path.join(os.path.expanduser(ckpt_dir), model_name)\n",
    "ckpt = tf.train.get_checkpoint_state(model_dir)\n",
    "# get_checkpoint_state returns None when no checkpoint state is available;\n",
    "# fail with an explicit message instead of an AttributeError on the next line.\n",
    "if ckpt is None:\n",
    "    raise FileNotFoundError('No checkpoint found in {!r}'.format(model_dir))\n",
    "ckpt_path = ckpt.model_checkpoint_path\n",
    "\n",
    "# Key each graph variable by its name with the first scope component\n",
    "# (everything up to the first '/') and the ':<index>' suffix removed --\n",
    "# presumably the names under which the checkpoint stores them.\n",
    "var_list = {v.name.split('/', 1)[1].split(':')[0]: v for v in tf.global_variables()}\n",
    "saver = tf.train.Saver(var_list=var_list)\n",
    "saver.restore(sess, ckpt_path)\n",
    "initialize_uninitialized_global_variables(sess)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## initialize the defense\n",
    "\n",
    "(default parameters from https://github.com/yk/icml19_public/blob/ace61a/tensorflow_example.py)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defense configuration:\n",
    "#   multi_noise = True  -> the defense is instantiated with 9 types of noise\n",
    "#   multi_noise = False -> the defense is instantiated with a single type of high-magnitude noise\n",
    "multi_noise = False\n",
    "predictor = init_defense(\n",
    "    sess, x, preds, batch_size,\n",
    "    multi_noise=multi_noise,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## evaluate the model and predictor on clean data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [05:22<00:00, 40.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of base model: 0.9620\n",
      "Accuracy of full defense: 0.9950\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Report the accuracy of the base model and of the full defense on the clean test points.\n",
    "_ = do_eval(\n",
    "    sess, x, y, preds, x_test, y_test,\n",
    "    'clean_train_clean_eval', False,\n",
    "    predictor=predictor,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Our first attack: standard \"logit-matching\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logits of the base model on the clean test points.\n",
    "logits_clean = batch_eval(sess, [x], [preds], [x_test], batch_size=batch_size)[0]\n",
    "\n",
    "# Target selection, kept deliberately simple: every input is steered towards\n",
    "# the logits of clean example 0, unless it has the same class as example 0,\n",
    "# in which case it is steered towards the logits of clean example 1.\n",
    "y_cls = np.argmax(y_test, axis=-1)\n",
    "\n",
    "# The two reference examples must have different classes and both must be\n",
    "# classified correctly by the base model.\n",
    "assert y_cls[0] != y_cls[1]\n",
    "assert np.argmax(logits_clean[0]) == y_cls[0]\n",
    "assert np.argmax(logits_clean[1]) == y_cls[1]\n",
    "\n",
    "same_cls_as_first = (y_cls == y_cls[0])\n",
    "target_logits = np.where(same_cls_as_first[:, None], logits_clean[1], logits_clean[0])\n",
    "\n",
    "# Attack objective: summed squared difference between target logits and model logits.\n",
    "target_logits_ph = tf.placeholder(tf.float32, shape=(None, nb_classes))\n",
    "loss = tf.reduce_sum(tf.square(target_logits_ph - preds))\n",
    "grad = tf.gradients(loss, x)[0]\n",
    "\n",
    "# Attack hyper-parameters: per-pixel perturbation bound, iteration count, step size.\n",
    "eps = 8.0\n",
    "nb_iter = 100\n",
    "step = (2.5 * eps) / nb_iter\n",
    "\n",
    "n_batches = math.ceil(x_test.shape[0] / batch_size)\n",
    "X_adv_all = x_test.copy()\n",
    "\n",
    "for b in range(n_batches):\n",
    "    lo, hi = b * batch_size, (b + 1) * batch_size\n",
    "    X = x_test[lo:hi]\n",
    "    Y = y_cls[lo:hi]\n",
    "    targets = target_logits[lo:hi]\n",
    "\n",
    "    # start each batch from the clean inputs\n",
    "    X_adv = X.copy()\n",
    "\n",
    "    for i in range(nb_iter):\n",
    "        feed = {x: X_adv, target_logits_ph: targets}\n",
    "        loss_np, grad_np, preds_np = sess.run([loss, grad, preds], feed_dict=feed)\n",
    "\n",
    "        # signed-gradient descent step on the loss, then project back into the\n",
    "        # eps-box around the clean input and into the [0, 255] pixel range\n",
    "        X_adv = X_adv - step * np.sign(grad_np)\n",
    "        X_adv = np.clip(np.clip(X_adv, X - eps, X + eps), 0, 255)\n",
    "\n",
    "        # progress: batch, iteration, loss and base-model accuracy (both measured before this step)\n",
    "        if i % 10 == 0:\n",
    "            print(b, i, loss_np, np.mean(np.argmax(preds_np, axis=-1) == Y))\n",
    "\n",
    "    X_adv_all[lo:hi] = X_adv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## evaluate the logit-matching attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [05:22<00:00, 40.35s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of base model: 0.0000\n",
      "Accuracy of full defense: 0.1630\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Report the accuracy of the base model and of the full defense on the logit-matching examples.\n",
    "_ = do_eval(\n",
    "    sess, x, y, preds, X_adv_all, y_test,\n",
    "    'logit-match', True,\n",
    "    predictor=predictor,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Our final attack: \"logit-matching\" + EOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Targets are built exactly as in the plain logit-matching attack: the logits of\n",
    "# clean example 0, or of clean example 1 for inputs that share example 0's class.\n",
    "y_cls = np.argmax(y_test, axis=-1)\n",
    "assert y_cls[0] != y_cls[1]\n",
    "logits_clean = batch_eval(sess, [x], [preds], [x_test], batch_size=batch_size)[0]\n",
    "\n",
    "assert np.argmax(logits_clean[0]) == y_cls[0]\n",
    "assert np.argmax(logits_clean[1]) == y_cls[1]\n",
    "target_logits = logits_clean.copy()\n",
    "target_logits[:] = logits_clean[0]\n",
    "target_logits[y_cls == y_cls[0]] = logits_clean[1]\n",
    "\n",
    "# simple squared loss over logits\n",
    "target_logits_ph = tf.placeholder(tf.float32, shape=(None, nb_classes))\n",
    "loss = tf.reduce_sum(tf.square(target_logits_ph - preds))\n",
    "grad = tf.gradients(loss, x)[0]\n",
    "\n",
    "n_batches = math.ceil(x_test.shape[0] / batch_size)\n",
    "X_adv_all2 = x_test.copy()\n",
    "\n",
    "# Attack hyper-parameters; they do not change between batches, so they are set once.\n",
    "# NOTE: `eps` (the perturbation bound) is defined in the logit-matching attack cell,\n",
    "# which therefore has to be run before this one.\n",
    "nb_iter = 100\n",
    "step = (2.5 * eps) / nb_iter\n",
    "nb_rand = 40  # noisy copies averaged per step, in addition to the noise-free point\n",
    "\n",
    "# choose the bound for the EOT noise to match the magnitude of the noise used by the defense\n",
    "eps_noise = 0.01 * 255 if multi_noise else 30.0\n",
    "\n",
    "for b in range(n_batches):\n",
    "    X = x_test[b*batch_size:(b+1)*batch_size]\n",
    "    Y = y_cls[b*batch_size:(b+1)*batch_size]\n",
    "    targets = target_logits[b*batch_size:(b+1)*batch_size]\n",
    "\n",
    "    X_adv = X.copy()\n",
    "\n",
    "    for i in range(nb_iter):\n",
    "        # loss / gradient at the current (noise-free) point\n",
    "        loss_np, grad_np, preds_np = sess.run([loss, grad, preds], feed_dict={x: X_adv, target_logits_ph: targets})\n",
    "\n",
    "        for j in range(nb_rand):\n",
    "\n",
    "            # if the defense uses multiple types of noise, perform EOT over all types.\n",
    "            # Cycle with j % 3 so that all three distributions are sampled; the previous\n",
    "            # j % 2 test could never reach the random-sign branch.\n",
    "            if multi_noise:\n",
    "                if j % 3 == 0:\n",
    "                    noise = np.random.normal(0., 1., size=X_adv.shape)\n",
    "                elif j % 3 == 1:\n",
    "                    noise = np.random.uniform(-1., 1., size=X_adv.shape)\n",
    "                else:\n",
    "                    noise = np.sign(np.random.uniform(-1., 1., size=X_adv.shape))\n",
    "            else:\n",
    "                noise = np.random.normal(0., 1., size=X_adv.shape)\n",
    "\n",
    "            X_adv_noisy = X_adv + noise * eps_noise\n",
    "            X_adv_noisy = X_adv_noisy.clip(0, 255)\n",
    "            loss_npi, grad_npi, preds_npi = sess.run([loss, grad, preds], feed_dict={x: X_adv_noisy, target_logits_ph: targets})\n",
    "\n",
    "            loss_np += loss_npi\n",
    "            grad_np += grad_npi\n",
    "\n",
    "        # average over the noise-free evaluation plus the nb_rand noisy ones\n",
    "        loss_np /= (nb_rand + 1)\n",
    "        grad_np /= (nb_rand + 1)\n",
    "\n",
    "        # signed-gradient descent step, then project into the eps-box around the\n",
    "        # clean input and into the [0, 255] pixel range\n",
    "        X_adv -= step * np.sign(grad_np)\n",
    "        X_adv = np.clip(X_adv, X-eps, X+eps)\n",
    "        X_adv = np.clip(X_adv, 0, 255)\n",
    "\n",
    "        if i % 10 == 0:\n",
    "            print(b, i, loss_np, np.mean(np.argmax(preds_np, axis=-1) == Y))\n",
    "\n",
    "    X_adv_all2[b*batch_size:(b+1)*batch_size] = X_adv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## evaluate the logit-matching + EOT attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [01:04<00:00, 32.22s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of base model: 0.0000\n",
      "Accuracy of full defense: 0.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Report the accuracy of the base model and of the full defense on the logit-matching + EOT examples.\n",
    "# The result is bound to `_`, as in the other evaluation cells, so the return value is not echoed as cell output.\n",
    "_ = do_eval(sess, x, y, preds, X_adv_all2, y_test, 'logit-match-eot', True, predictor=predictor)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "odds",
   "language": "python",
   "name": "odds"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
